Bayesian analysis for Coronary Artery Disease detection

Statistical Methods for Data Science 2 - Final project

Author

Lorenzo Pannacci - 1948926

Published

25 09 2024

# load all required libraries
library(knitr)
library(plotly)
library(ggplot2)
library(gridExtra)
library(caret)
library(dplyr)
library(corrplot)
library(R2jags)
library(ggmcmc)

Introduction and motivation

Coronary Artery Disease (CAD; in Italian “Coronaropatia”) is a type of Cardiovascular Disease (CVD) in which the coronary arteries cannot deliver enough oxygen-rich blood to the heart. According to the National Heart, Lung, and Blood Institute website, “CVDs are the leading cause of death in the United States” (695,000 deaths per year in the US, 1 in every 5 deaths) and “coronary heart disease is the most common type of CVD” (375,476 deaths per year in the US), with similar proportions around the whole world. In 2021 CVDs were the world’s single biggest killer, with 20.5 million deaths globally.

CAD is caused by the build-up of cholesterol plaque in the coronary arteries, a phenomenon called atherosclerosis. The deposits can partially or totally block the flow of blood in the arteries, or promote their blockage through the formation of blood clots.

Infographic on how coronary heart disease works. Image from Cleveland Clinic.

Symptoms of CAD include chest pain, heartburn and shortness of breath, though their presence varies from person to person even among patients with the same type of coronary heart disease. Much more commonly, however, the disease presents no symptoms at all, with the first manifest sign being a heart attack, which can itself cause a cardiac arrest, a life-threatening medical emergency if not treated within minutes: 25% of people who have CAD die suddenly without any previous symptom.

Given the difficulty of detection, the severity of its effects and its widespread diffusion throughout the world’s population, a timely and accurate diagnosis of the disease in its early, largely asymptomatic phase is extremely important for early treatment and could save many lives every year.

The proposed study uses Bayesian inference to build a model capable of detecting CAD from non-invasive clinical parameters of patients (as opposed to the coronary angiogram, described below). The Bayesian approach is particularly suitable here given the low number of entries in the dataset under examination.

Data

Dataset description

The data used for this project come from a publicly available dataset called “Z-Alizadeh Sani”, collected in 2012 for research purposes by Dr. Zahra Alizadeh Sani, associate professor of Cardiology at Iran University, and donated to the UC Irvine Machine Learning Repository (available here) in 2017.

A particularly positive aspect of this dataset is its completeness: there are no missing values.

The dataset contains the records of 303 individuals, random visitors to the Shaheed Rajaei Cardiovascular, Medical and Research Center of Tehran. Each sample has 55 features and a target category. All the features recorded in the dataset were chosen by the author because they are considered indicators of CAD in the current medical literature. The features are arranged in four groups:

  • Demographic
  • Symptom and examination
  • ECG (electrocardiogram)
  • Laboratory and echo

The ground-truth classification is the result of a coronary angiogram performed on the patient, an imaging technique that visualizes blood vessels, arteries and veins using X-rays and a radio-opaque contrast agent injected into the bloodstream. While accurate, this method is rarely performed due to its high cost and the invasiveness of the procedure.

Angiogram of the heart showing a severe coronary narrowing. Image from Capital Heart Centre.

Each patient belongs to one of two possible categories: CAD or Normal. A patient is categorized as CAD if the diameter narrowing of an artery is greater than or equal to 50%, and as Normal otherwise. Of the 303 individuals, 216 have CAD and the remaining 87 are healthy.

The following table summarizes the features of the dataset, their meaning and the values they take:

Feature Type Feature Name Range
Demographic Age (years) 30–86
Weight (kg) 48–120
Length (height, cm) 140–188
Sex male, female
BMI (body mass index, kg/m²) 18–41
DM (history of diabetes mellitus) yes, no (binary)
HTN (history of hypertension) yes, no (binary)
Current smoker yes, no (binary)
Ex Smoker yes, no (binary)
FH (history of CVD in first-degree relatives) yes, no (binary)
Obesity yes, no (string)
CRF (chronic renal failure) yes, no (string)
CVA (cerebrovascular accident) yes, no (string)
Airway disease yes, no (string)
Thyroid Disease yes, no (string)
CHF (congestive heart failure) yes, no (string)
DLP (dyslipidemia, high lipids in blood) yes, no (string)
Symptom and Examination BP (blood pressure, mmHg) 90–190
PR (pulse rate, ppm) 50–110
Edema (fluid retention in body tissue) yes, no (binary)
Weak peripheral pulse yes, no (string)
Lung rales (abnormal lung sounds) yes, no (string)
Systolic murmur yes, no (string)
Diastolic murmur yes, no (string)
Typical Chest Pain yes, no (binary)
Dyspnea (shortness of breath) yes, no (string)
Function class (frequency of symptoms) 1, 2, 3, 4
Atypical yes, no (string)
Nonanginal CP (chest pain at rest) yes, no (string)
Exertional CP (chest Pain during physical exertion) yes, no (string)
Low Th Ang (low threshold angina) yes, no (string)
ECG (electrocardiogram) Q Wave yes, no (binary)
ST Elevation yes, no (binary)
ST Depression yes, no (binary)
T inversion yes, no (binary)
LVH (left ventricular hypertrophy) yes, no (string)
Poor R progression yes, no (string)
BBB (bundle branch block) no, left, right
Laboratory and Echo FBS (fasting blood sugar, mg/dl) 62–400
Cr (creatinine, mg/dl) 0.5–2.2
TG (triglyceride, mg/dl) 37–1050
LDL (low density lipoprotein, mg/dl) 18–232
HDL (high density lipoprotein, mg/dl) 15–111
BUN (blood urea nitrogen, mg/dl) 6–52
ESR (erythrocyte sedimentation rate, mm/h) 1–90
HB (hemoglobin, g/dl) 8.9–17.6
K (potassium, mEq/lit) 3.0–6.6
Na (sodium, mEq/lit) 128–156
WBC (white blood cell, cells/ml) 3700–18000
Lymph (lymphocyte, %) 7–60
Neut (neutrophil, %) 32–89
PLT (platelet, 1000/ml) 25–742
EF (ejection fraction, %) 9–65
Region with RWMA (regional wall motion abnormality) 0, 1, 2, 3, 4
VHD (valvular heart disease) normal, mild, moderate, severe
Target Cath (cardiac catheterization) cad, Normal

Preprocessing and cleaning

To give an idea of the raw data, here are some example entries from the dataset:

# read CSV file
data = read.csv("Z-Alizadeh Sani dataset.csv")

# print head
knitr::kable(head(data, 5), col.names = gsub("[.]", " ", names(data)))
Age Weight Length Sex BMI DM HTN Current Smoker EX Smoker FH Obesity CRF CVA Airway disease Thyroid Disease CHF DLP BP PR Edema Weak Peripheral Pulse Lung rales Systolic Murmur Diastolic Murmur Typical Chest Pain Dyspnea Function Class Atypical Nonanginal Exertional CP LowTH Ang Q Wave St Elevation St Depression Tinversion LVH Poor R Progression BBB FBS CR TG LDL HDL BUN ESR HB K Na WBC Lymph Neut PLT EF TTE Region RWMA VHD Cath
53 90 175 Male 29.38776 0 1 1 0 0 Y N N N N N Y 110 80 0 N N N N 0 N 0 N N N N 0 0 1 1 N N N 90 0.7 250 155 30 8 7 15.6 4.7 141 5700 39 52 261 50 0 N Cad
67 70 157 Fmale 28.39872 0 1 0 0 0 Y N N N N N N 140 80 1 N N N N 1 N 0 N N N N 0 0 1 1 N N N 80 1.0 309 121 36 30 26 13.9 4.7 156 7700 38 55 165 40 4 N Cad
54 54 164 Male 20.07733 0 0 1 0 0 N N N N N N N 100 100 0 N N N N 1 N 0 N N N N 0 0 0 0 N N N 85 1.0 103 70 45 17 10 13.5 4.7 139 7400 38 60 230 40 2 mild Cad
66 67 158 Fmale 26.83865 0 1 0 0 0 Y N N N N N N 100 80 0 N N N Y 0 Y 3 N Y N N 0 0 1 0 N N N 78 1.2 63 55 27 30 76 12.1 4.4 142 13000 18 72 742 55 0 Severe Normal
50 87 153 Fmale 37.16519 0 1 0 0 0 Y N N N N N N 110 80 0 N N Y N 0 Y 2 N N N N 0 0 0 0 N N N 104 1.0 170 110 50 16 27 13.2 4.0 140 9200 55 39 274 50 0 Severe Normal


Since the features are represented in many different formats, the first thing we have to do is data cleaning and preparation. Binary features stored as yes/no strings become booleans, and the same applies to Sex and the target variable Cath.

The feature VHD is populated with four different kinds of strings; since they imply an ordering, we can easily convert them into the integers 0–3.

The feature BBB also contains multiple strings, but since there is no ordering among its values we have to perform a one-hot encoding, creating additional variables.
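As a language-agnostic illustration of these encoding steps (the report’s own code is in R, and the exact string labels used here are assumptions for illustration), a minimal sketch applied to one toy record:

```python
# Sketch of the three encodings described above: yes/no strings -> 0/1,
# ordinal strings (VHD) -> integers, unordered strings (BBB) -> one-hot columns.
def encode(record):
    binary_map = {"N": 0, "Y": 1}                               # yes/no -> boolean
    vhd_map = {"N": 0, "mild": 1, "Moderate": 2, "Severe": 3}   # ordinal encoding
    out = {}
    out["Obesity"] = binary_map[record["Obesity"]]
    out["VHD"] = vhd_map[record["VHD"]]
    # BBB has no ordering between values -> two indicator variables
    out["LBBB"] = 1 if record["BBB"] == "LBBB" else 0
    out["RBBB"] = 1 if record["BBB"] == "RBBB" else 0
    return out

print(encode({"Obesity": "Y", "VHD": "Severe", "BBB": "N"}))
# -> {'Obesity': 1, 'VHD': 3, 'LBBB': 0, 'RBBB': 0}
```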

We also removed the feature Exertional CP, since all its entries have the same value. This is how the same rows appear after the preprocessing:

Age Weight Length Sex BMI DM HTN Current Smoker EX Smoker FH Obesity CRF CVA Airway disease Thyroid Disease CHF DLP BP PR Edema Weak Peripheral Pulse Lung rales Systolic Murmur Diastolic Murmur Typical Chest Pain Dyspnea Function Class Atypical Nonanginal LowTH Ang Q Wave St Elevation St Depression Tinversion LVH Poor R Progression FBS CR TG LDL HDL BUN ESR HB K Na WBC Lymph Neut PLT EF TTE Region RWMA VHD Cath LBBB RBBB
53 90 175 1 29.38776 0 1 1 0 0 1 0 0 0 0 0 1 110 80 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 90 0.7 250 155 30 8 7 15.6 4.7 141 5700 39 52 261 50 0 0 Cad 0 0
67 70 157 0 28.39872 0 1 0 0 0 1 0 0 0 0 0 0 140 80 1 0 0 0 0 1 0 0 0 0 0 0 0 1 1 0 0 80 1.0 309 121 36 30 26 13.9 4.7 156 7700 38 55 165 40 4 0 Cad 0 0
54 54 164 1 20.07733 0 0 1 0 0 0 0 0 0 0 0 0 100 100 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 85 1.0 103 70 45 17 10 13.5 4.7 139 7400 38 60 230 40 2 1 Cad 0 0
66 67 158 0 26.83865 0 1 0 0 0 1 0 0 0 0 0 0 100 80 0 0 0 0 1 0 1 3 0 1 0 0 0 1 0 0 0 78 1.2 63 55 27 30 76 12.1 4.4 142 13000 18 72 742 55 0 3 Normal 0 0
50 87 153 0 37.16519 0 1 0 0 0 1 0 0 0 0 0 0 110 80 0 0 0 1 0 0 1 2 0 0 0 0 0 0 0 0 0 104 1.0 170 110 50 16 27 13.2 4.0 140 9200 55 39 274 50 0 3 Normal 0 0

Exploratory data analysis

An exploratory data analysis can give us a lot of information on how the data behave and on the differences in distribution between the two target classes, helping us tell them apart as well as showing which features would just introduce noise and are better left behind. We can experiment with many different plots; let’s start by plotting the features as histograms:

Despite the low number of entries, we can observe that many continuous variables take approximately the shape of a Normal distribution, while, for example, EF TTE seems to take a shape more similar to a Beta distribution. The histograms of the discrete variables don’t give us much information in this form, but we can observe a strong imbalance in most of the features.

Another interesting observation that can be made with plots is how the behavior of the distributions changes when conditioned on the target class:

From those plots we can infer some interesting patterns: CAD tends to be correlated with higher age, increased blood pressure and higher fasting blood sugar. Males seem to be more affected than females, as do diabetic people, and some binary variables take a positive value exclusively in the CAD class, in particular those in the electrocardiogram subset of features.

Particularly interesting is the plot of Typical Chest Pain, where we can see that about 25% of patients with CAD do not experience it, a statistic that fits well with what we found online and reported in the introduction.

The data also show some unusual behavior: for all features linked to weight, such as Weight, BMI, HDL and Obesity, there is little to no difference between the distributions of the two classes, while common knowledge says that being overweight leads to a marked increase in the occurrence of heart disease. We can interpret this observation as a bias introduced by how the data were gathered: since all observations are visitors of a medical facility specialized in cardiovascular diseases, the recorded samples are not representative of the overall population. This reasoning led to the counterintuitive decision of excluding these variables from the study.

Finally, we can analyze the linear correlation between the variables with a correlation plot. Given the high number of variables, a complete correlation matrix would be messy to read; for this reason we decided to use only the top 10 most correlated features:
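Such a ranking can be obtained by scoring each feature by the absolute value of its Pearson correlation with the target and keeping the largest scores. A minimal sketch of the idea (in Python for illustration; the report does this in R on the real dataset, and the toy feature names below are made up):

```python
# Rank features by |Pearson correlation with the target| and keep the top k.
import math

def pearson(xs, ys):
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

def top_correlated(features, target, k):
    scores = {name: abs(pearson(col, target)) for name, col in features.items()}
    return sorted(scores, key=scores.get, reverse=True)[:k]

features = {
    "age":   [30, 45, 60, 70, 55],  # rises with the target -> high correlation
    "noise": [1, 9, 2, 8, 5],       # unrelated to the target
}
target = [0, 0, 1, 1, 1]
print(top_correlated(features, target, 1))  # -> ['age']
```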

Goals and measures

Since the target variable of this study is binary, the project aims to build a Bayesian logistic regression. As noted above, the Bayesian approach is particularly suitable for this problem since it lets us incorporate prior information into the model and is more robust than a frequentist approach in scenarios where few samples are available.

Our aim in terms of model performance is to get the best precision and recall scores possible, but, as with many other medical models, we want to prioritize a high recall: in disease detection it is vital to identify as many true cases as possible, since false positives can be resolved simply with further testing while false negatives may endanger the patient.
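To make these metrics concrete, a small worked sketch computing them from a confusion matrix (in Python for illustration; the tp/fp/fn counts below are made up, not the model’s actual results):

```python
# Precision = tp / (tp + fp), recall = tp / (tp + fn),
# F1 = harmonic mean of precision and recall.
def metrics(tp, fp, fn):
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Hypothetical counts: 9 true positives, 1 false positive, 3 false negatives.
p, r, f1 = metrics(tp=9, fp=1, fn=3)
print(round(p, 3), round(r, 3), round(f1, 3))  # -> 0.9 0.75 0.818
```

Note how a false negative (a missed CAD case) lowers recall, the quantity we care most about.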

To measure performance and prevent the measurements from being positively distorted by overfitting, we divide the dataset into two subsets: a train set used for fitting and a test set of never-before-seen data used to test the inference capabilities of the model. We use the classic proportion of 75% for train and 25% for test, and set the random seed to ensure the reproducibility of the split.

# set seed for reproducibility
set.seed(123)

train_indices = sample(seq_len(nrow(data)), size = 0.75 * nrow(data))

# split into train and test sets
train_data = data[train_indices, ]
test_data = data[-train_indices, ]

Modelling the Bayesian problem

The model we want to build is a logistic regression. In this model we use our set of features \(X = (X_1, ..., X_n)\) to infer the value of a target binary categorical variable:

\[ Y = \left\{ \begin{array}{cl} 1 & \text{CAD} \\ 0 & \text{Normal} \end{array} \right. \]

Unlike linear regression, logistic regression models as a linear combination of the parameters not a scalar value but the probability of an event, measured in log-odds, defined as \(\phi = \ln\left( \frac{p}{1-p} \right)\), where \(\frac{p}{1-p}\) is the definition of the odds and \(p \in [0,1]\) is the probability of the event. The regression we are modelling is therefore:

\[ \phi = \ln\left( \frac{p}{1-p} \right) = \beta_0 + \beta_1 X_1 + \dots + \beta_n X_n \]

During the inference step we simply invert the formula to obtain the probability:

\[ p = \frac{\exp(\phi)}{1 + \exp(\phi)} = \frac{1}{1 + \exp(- \phi)} \]

And as prediction we just take the class with the highest probability:

\[ \text{pred} = \left\{ \begin{array}{cl} 1 & \text{if } p \ge 0.5 \\ 0 & \text{otherwise} \end{array} \right. \]
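The round trip between probabilities and log-odds can be checked numerically; a minimal sketch (in Python for illustration; the report’s own code is in R):

```python
# Map a probability to log-odds, invert back with the sigmoid, then threshold.
import math

def log_odds(p):
    return math.log(p / (1 - p))

def sigmoid(phi):
    return 1 / (1 + math.exp(-phi))

def predict(p, threshold=0.5):
    return 1 if p >= threshold else 0

p = 0.8
phi = log_odds(p)       # ln(0.8 / 0.2) = ln(4) ~ 1.386
p_back = sigmoid(phi)   # recovers 0.8
print(round(phi, 3), round(p_back, 3), predict(p_back))  # -> 1.386 0.8 1
```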

Logistic regression is just one form of the Generalized Linear Model; in particular, it is what we obtain when the target is distributed as a Bernoulli random variable and the log-odds serve as its link function.

Once we have the structure of the model we have to use the training data to infer good values for the \(\beta\) parameters.
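The quantity the fitting procedure works with (combined with the priors, in the Bayesian case) is the Bernoulli log-likelihood of the training labels under the logistic model. A small sketch on toy data (in Python for illustration; the data and parameter values are made up):

```python
# Bernoulli log-likelihood of labels y given logistic-regression parameters:
# log L = sum_i [ y_i * log(p_i) + (1 - y_i) * log(1 - p_i) ],
# with p_i = sigmoid(beta0 + beta . x_i).
import math

def log_likelihood(beta0, beta, X, y):
    total = 0.0
    for xi, yi in zip(X, y):
        phi = beta0 + sum(b * x for b, x in zip(beta, xi))
        p = 1 / (1 + math.exp(-phi))
        total += yi * math.log(p) + (1 - yi) * math.log(1 - p)
    return total

X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 0, 1, 1]
# A slope that separates the classes fits better than a flat model.
print(log_likelihood(-3.0, [2.0], X, y) > log_likelihood(0.0, [0.0], X, y))  # -> True
```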

Models

1. Naive model (baseline)

Our first model, which we call “naive”, uses every feature defined in the dataset for the logistic regression. We can use this model as a baseline for the others. As priors we decided to use weakly informative priors: normal distributions with zero mean and high variance.

jags_code = "model{

  ##########
  # PRIORS #
  ##########
  
  # for the intercept we use a normal with mean 0 and precision 0.01
  beta0 ~ dnorm(0, 0.01)  

  # priors for the coefficients
  # for those we use a normal with mean 0 and precision 0.4
  
  for (j in 1:n_features){
    beta[j] ~ dnorm(0, 0.4)  
  }
  
  ##############
  # LIKELIHOOD #
  ##############
  
  for (i in 1:n_samples){
    # calculate logits (log-odds)
    logit_p[i] = beta0 + inprod(beta[1:n_features], x[i,])
    
    # convert log-odds into probabilities
    p[i] = 1 / (1 + exp(- logit_p[i]))
    
    # get the binary outcome for the target
    y[i] ~ dbern(p[i])
  }

}"

# `exclude` (defined in an earlier, unshown chunk) lists the columns left out
# of the regression: the target Cath and the discarded features
features = setdiff(colnames(data), exclude)

# pass parameters to format for JAGS
model_data = list(
  x = as.matrix(train_data[, features]),
  y = train_data$Cath,
  n_samples = nrow(train_data),
  n_features = length(features)
)

jags_model1 = jags(model.file=textConnection(jags_code),
                  data = model_data, 
                  inits = NULL,
                  n.chains = 5,
                  n.iter = 15000,
                  n.burnin = 5000,
                  parameters.to.save = c("beta0", "beta"))
Compiling model graph
   Resolving undeclared variables
   Allocating nodes
Graph information:
   Observed stochastic nodes: 227
   Unobserved stochastic nodes: 53
   Total graph size: 13680

Initializing model
jags_model1
Inference for Bugs model at "4", fit using jags,
 5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
 n.sims = 5000 iterations saved
         mu.vect sd.vect    2.5%    25%     50%     75%   97.5%  Rhat n.eff
beta[1]    0.126   0.039   0.056  0.099   0.125   0.151   0.207 1.017   210
beta[2]    0.686   0.840  -0.935  0.111   0.659   1.252   2.353 1.003  1100
beta[3]    1.631   0.897  -0.095  1.029   1.629   2.256   3.343 1.001  5000
beta[4]    1.388   0.817  -0.204  0.828   1.386   1.935   2.981 1.005   780
beta[5]    0.408   0.828  -1.179 -0.160   0.399   0.962   2.096 1.002  2100
beta[6]    0.468   1.353  -2.250 -0.437   0.450   1.364   3.168 1.002  3000
beta[7]    1.924   0.872   0.185  1.345   1.918   2.498   3.639 1.003  1400
beta[8]   -0.405   0.722  -1.829 -0.895  -0.387   0.076   0.980 1.001  3000
beta[9]    0.102   1.573  -2.994 -0.949   0.105   1.176   3.191 1.001  4900
beta[10]   0.567   1.299  -1.984 -0.306   0.572   1.469   3.110 1.001  5000
beta[11]   1.036   1.293  -1.491  0.146   1.036   1.906   3.524 1.001  5000
beta[12]  -0.001   1.298  -2.526 -0.861  -0.003   0.858   2.550 1.001  3500
beta[13]  -0.036   1.584  -3.215 -1.104  -0.037   1.003   3.010 1.001  5000
beta[14]   0.044   0.679  -1.255 -0.420   0.041   0.499   1.378 1.001  5000
beta[15]   0.040   0.026  -0.011  0.023   0.039   0.057   0.091 1.008   850
beta[16]   0.088   0.051  -0.013  0.057   0.087   0.121   0.189 1.020   320
beta[17]  -0.326   1.242  -2.700 -1.179  -0.334   0.512   2.075 1.002  2500
beta[18]   0.196   1.529  -2.780 -0.813   0.176   1.227   3.244 1.001  5000
beta[19]   1.305   1.325  -1.281  0.410   1.286   2.205   3.908 1.002  2200
beta[20]   0.293   1.074  -1.752 -0.431   0.276   1.013   2.395 1.002  1600
beta[21]  -0.366   1.242  -2.789 -1.211  -0.351   0.470   2.045 1.001  5000
beta[22]   3.444   0.896   1.732  2.844   3.444   4.046   5.209 1.001  5000
beta[23]  -2.403   0.784  -3.980 -2.921  -2.379  -1.863  -0.924 1.009   350
beta[24]   0.688   0.404  -0.101  0.409   0.689   0.963   1.472 1.005   690
beta[25]  -0.471   0.888  -2.229 -1.072  -0.476   0.145   1.283 1.001  5000
beta[26]  -1.377   1.133  -3.579 -2.157  -1.377  -0.613   0.859 1.001  5000
beta[27]  -0.011   1.595  -3.122 -1.070  -0.024   1.058   3.100 1.001  5000
beta[28]   0.679   1.380  -2.031 -0.248   0.661   1.599   3.392 1.002  2100
beta[29]   1.243   1.332  -1.336  0.380   1.209   2.114   3.929 1.001  5000
beta[30]   1.381   0.926  -0.368  0.729   1.361   2.012   3.201 1.006   580
beta[31]   1.984   0.846   0.329  1.428   1.982   2.547   3.658 1.001  4100
beta[32]   0.527   1.098  -1.640 -0.208   0.523   1.273   2.689 1.002  2500
beta[33]   0.254   1.489  -2.623 -0.756   0.241   1.275   3.174 1.001  5000
beta[34]   0.003   0.009  -0.015 -0.003   0.003   0.009   0.021 1.001  5000
beta[35]   0.548   1.179  -1.784 -0.237   0.553   1.312   2.894 1.002  2400
beta[36]   0.013   0.006   0.003  0.009   0.013   0.017   0.025 1.003  1300
beta[37]   0.002   0.012  -0.020 -0.006   0.002   0.010   0.026 1.005   640
beta[38]   0.011   0.028  -0.042 -0.007   0.011   0.030   0.068 1.008   390
beta[39]  -0.036   0.064  -0.165 -0.079  -0.035   0.007   0.089 1.006   540
beta[40]   0.003   0.029  -0.054 -0.016   0.003   0.023   0.061 1.001  4200
beta[41]  -0.387   0.305  -1.035 -0.585  -0.364  -0.171   0.148 1.050    74
beta[42]   1.355   0.827  -0.239  0.792   1.323   1.927   3.006 1.019   220
beta[43]  -0.120   0.066  -0.273 -0.155  -0.116  -0.073  -0.012 1.219    21
beta[44]   0.000   0.000   0.000  0.000   0.000   0.000   0.001 1.004   990
beta[45]  -0.068   0.086  -0.229 -0.127  -0.066  -0.016   0.120 1.030   130
beta[46]  -0.050   0.082  -0.193 -0.110  -0.050   0.005   0.128 1.043   110
beta[47]  -0.005   0.007  -0.020 -0.010  -0.005  -0.001   0.008 1.011   290
beta[48]  -0.052   0.052  -0.156 -0.087  -0.051  -0.016   0.049 1.010   310
beta[49]   2.753   0.885   1.194  2.134   2.700   3.313   4.641 1.001  3900
beta[50]  -0.724   0.605  -1.941 -1.119  -0.709  -0.320   0.459 1.001  3900
beta[51]  -0.984   1.211  -3.389 -1.804  -0.954  -0.172   1.370 1.002  1900
beta[52]  -0.231   1.198  -2.607 -1.048  -0.218   0.577   2.092 1.002  1800
beta0      0.566   7.456 -15.661 -4.146   0.465   5.425  14.725 1.086    42
deviance 103.363   8.396  88.288 97.459 102.775 108.812 120.908 1.007   460

For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).

DIC info (using the rule, pD = var(deviance)/2)
pD = 35.0 and DIC = 138.3
DIC is an estimate of expected predictive error (lower deviance is better).
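The DIC rule quoted in the output above can be reproduced by hand: pD = var(deviance)/2 and DIC = mean(deviance) + pD. A sketch on synthetic deviance samples (the values below are made up, not the model’s actual posterior draws):

```python
# Compute pD and DIC from posterior samples of the deviance,
# using the rule pD = var(deviance) / 2, DIC = mean(deviance) + pD.
def dic_from_deviance(dev):
    n = len(dev)
    mean_dev = sum(dev) / n
    var_dev = sum((d - mean_dev) ** 2 for d in dev) / (n - 1)  # sample variance
    p_d = var_dev / 2
    return mean_dev + p_d, p_d

deviance_samples = [100.0, 104.0, 98.0, 106.0, 102.0]
dic, p_d = dic_from_deviance(deviance_samples)
print(dic, p_d)  # -> 107.0 5.0
```

As a sanity check against the printed summary: with mean deviance 103.363 and sd 8.396, pD = 8.396²/2 ≈ 35.2 and DIC ≈ 138.3, matching the reported values.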
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model1)))

# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
  group_by(Parameter) %>%
  summarize(mean_value = mean(value)) %>%
  spread(key = Parameter, value = mean_value)

# prepare the test data with all predictors
model_test_data = as.matrix(test_data[, features])

# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 + 
  model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])

# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))

# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)

# calculate metrics
accuracy = sum(predictions == test_data$Cath) / nrow(test_data)
recall = sum(predictions == 1 & test_data$Cath == 1) / sum(test_data$Cath == 1)
precision = sum(predictions == 1 & test_data$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)

# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy:  0.842105263157895"
print(paste("Recall: ", recall))
[1] "Recall:  0.924528301886792"
print(paste("Precision: ", precision))
[1] "Precision:  0.859649122807018"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score:  0.890909090909091"

2. Model with feature selection

The second model is very similar to the naive one, but some features observed to introduce noise are removed. The result should be a slightly simpler model, which should help convergence.

exclude2 = c("Cath", "Weight", "Length", "BMI", "Atypical", "Nonanginal", "FBS", "Diastolic.Murmur", "Current.Smoker", "EX.Smoker")
features = setdiff(colnames(data), exclude2)

# pass parameters to format for JAGS
model_data = list(
  x = as.matrix(train_data[, features]),
  y = train_data$Cath,
  n_samples = nrow(train_data),
  n_features = length(features)
)

jags_model2 = jags(model.file=textConnection(jags_code),
                  data = model_data, 
                  inits = NULL,
                  n.chains = 5,
                  n.iter = 15000,
                  n.burnin = 5000,
                  parameters.to.save = c("beta0", "beta"))
Compiling model graph
   Resolving undeclared variables
   Allocating nodes
Graph information:
   Observed stochastic nodes: 227
   Unobserved stochastic nodes: 47
   Total graph size: 12312

Initializing model
jags_model2
Inference for Bugs model at "5", fit using jags,
 5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
 n.sims = 5000 iterations saved
         mu.vect sd.vect    2.5%     25%     50%     75%   97.5%  Rhat n.eff
beta[1]    0.119   0.039   0.047   0.091   0.117   0.144   0.201 1.007  1100
beta[2]    0.849   0.813  -0.705   0.302   0.827   1.390   2.501 1.011   280
beta[3]    1.865   0.777   0.382   1.341   1.852   2.386   3.409 1.001  3800
beta[4]    1.212   0.774  -0.289   0.686   1.218   1.717   2.726 1.006   540
beta[5]    1.844   0.853   0.204   1.261   1.840   2.424   3.532 1.004   900
beta[6]   -0.438   0.715  -1.839  -0.922  -0.435   0.040   0.976 1.001  3500
beta[7]    0.137   1.539  -2.840  -0.913   0.158   1.151   3.187 1.002  2000
beta[8]    0.611   1.267  -1.846  -0.258   0.612   1.442   3.133 1.002  2300
beta[9]    1.024   1.268  -1.424   0.164   1.018   1.876   3.497 1.002  2600
beta[10]  -0.071   1.252  -2.472  -0.936  -0.077   0.768   2.432 1.001  5000
beta[11]  -0.003   1.556  -3.041  -1.033   0.002   1.022   3.035 1.001  5000
beta[12]   0.038   0.662  -1.255  -0.415   0.040   0.487   1.335 1.004   940
beta[13]   0.038   0.023  -0.006   0.022   0.037   0.053   0.084 1.024   140
beta[14]   0.086   0.051  -0.016   0.052   0.087   0.118   0.190 1.021   230
beta[15]  -0.242   1.204  -2.622  -1.052  -0.234   0.572   2.121 1.002  2000
beta[16]   0.167   1.484  -2.686  -0.854   0.173   1.167   3.135 1.002  2100
beta[17]   1.180   1.298  -1.364   0.290   1.167   2.046   3.745 1.001  4600
beta[18]   0.197   1.029  -1.821  -0.503   0.202   0.889   2.213 1.005   670
beta[19]   3.859   0.674   2.583   3.390   3.850   4.304   5.225 1.004   940
beta[20]  -2.313   0.766  -3.851  -2.837  -2.313  -1.792  -0.840 1.012   270
beta[21]   0.670   0.384  -0.087   0.412   0.669   0.921   1.443 1.002  2700
beta[22]  -0.009   1.576  -3.069  -1.062  -0.022   1.061   3.104 1.002  1900
beta[23]   0.701   1.380  -1.961  -0.250   0.665   1.578   3.547 1.001  3100
beta[24]   1.246   1.325  -1.337   0.348   1.238   2.119   3.902 1.002  1700
beta[25]   1.295   0.922  -0.554   0.676   1.300   1.909   3.080 1.014   230
beta[26]   1.947   0.833   0.363   1.407   1.934   2.489   3.623 1.004   950
beta[27]   0.784   1.102  -1.363   0.031   0.780   1.527   2.908 1.005   660
beta[28]   0.338   1.471  -2.474  -0.671   0.337   1.327   3.276 1.003  1300
beta[29]   0.622   1.130  -1.625  -0.114   0.634   1.367   2.818 1.007   540
beta[30]   0.012   0.005   0.003   0.009   0.012   0.016   0.023 1.009   370
beta[31]   0.001   0.011  -0.021  -0.007   0.001   0.008   0.023 1.011   290
beta[32]   0.013   0.028  -0.042  -0.006   0.012   0.031   0.069 1.013   250
beta[33]  -0.029   0.060  -0.147  -0.069  -0.028   0.011   0.087 1.008   400
beta[34]   0.003   0.029  -0.053  -0.016   0.003   0.023   0.063 1.016   200
beta[35]  -0.373   0.315  -1.032  -0.588  -0.356  -0.151   0.223 1.085    44
beta[36]   1.496   0.816  -0.108   0.924   1.498   2.088   3.063 1.045    72
beta[37]  -0.100   0.091  -0.311  -0.159  -0.097  -0.032   0.059 1.251    19
beta[38]   0.000   0.000   0.000   0.000   0.000   0.000   0.001 1.011   320
beta[39]  -0.038   0.092  -0.223  -0.103  -0.039   0.028   0.134 1.015   210
beta[40]  -0.020   0.090  -0.197  -0.084  -0.017   0.048   0.143 1.023   150
beta[41]  -0.007   0.007  -0.020  -0.011  -0.006  -0.002   0.006 1.022   150
beta[42]  -0.034   0.055  -0.145  -0.071  -0.033   0.005   0.072 1.030   120
beta[43]   2.709   0.863   1.188   2.111   2.663   3.240   4.545 1.004   900
beta[44]  -0.613   0.590  -1.777  -1.000  -0.603  -0.213   0.525 1.003  1400
beta[45]  -0.963   1.181  -3.328  -1.747  -0.947  -0.163   1.313 1.002  2500
beta[46]  -0.163   1.152  -2.458  -0.923  -0.155   0.597   2.071 1.002  2500
beta0     -5.867   7.724 -21.395 -11.500  -5.116  -0.111   8.622 1.139    27
deviance 101.437   8.071  87.288  95.677 100.966 106.641 118.863 1.014   220

For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).

DIC info (using the rule, pD = var(deviance)/2)
pD = 32.0 and DIC = 133.4
DIC is an estimate of expected predictive error (lower deviance is better).
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model2)))

# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
  group_by(Parameter) %>%
  summarize(mean_value = mean(value)) %>%
  spread(key = Parameter, value = mean_value)

# prepare the test data with all predictors
model_test_data = as.matrix(test_data[, features])

# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 + 
  model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])

# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))

# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)

# calculate metrics
accuracy = sum(predictions == test_data$Cath) / nrow(test_data)
recall = sum(predictions == 1 & test_data$Cath == 1) / sum(test_data$Cath == 1)
precision = sum(predictions == 1 & test_data$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)

# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy:  0.842105263157895"
print(paste("Recall: ", recall))
[1] "Recall:  0.924528301886792"
print(paste("Precision: ", precision))
[1] "Precision:  0.859649122807018"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score:  0.890909090909091"

3. Model with feature engineering

This third model takes inspiration from the original paper on this dataset by Dr. Alizadeh Sani. Some variables are discretized into “Low”, “Normal” and “High” levels; for some features the thresholds depend on the sex of the patient.

Feature Low Normal High
Cr Cr < 0.7 0.7 ≤ Cr ≤ 1.5 Cr > 1.5
FBS FBS < 70 70 ≤ FBS ≤ 105 FBS > 105
LDL LDL < 130 LDL > 130
HDL HDL < 35 HDL ≥ 35
BUN BUN < 7 7 ≤ BUN ≤ 20 BUN > 20
ESR If male and ESR ≤ age/2, if female and ESR ≤ age/2 + 5 If male and ESR > age/2, if female and ESR > age/2 + 5
HB If male and HB < 14, if female and HB < 12.5 If male and 14 ≤ HB ≤ 17, if female and 12.5 ≤ HB ≤ 15 If male and HB > 17, if female and HB > 15
K K < 3.8 3.8 ≤ K ≤ 5.6 K > 5.6
Na Na < 136 136 ≤ Na ≤ 146 Na > 146
WBC WBC < 4000 4000 ≤ WBC ≤ 11,000 WBC > 11,000
PLT PLT < 150 150 ≤ PLT ≤ 450 PLT > 450
EF EF ≤ 50 EF > 50
Region with RWMA Region with RWMA = 0 Region with RWMA ≠ 0
Age If male and age ≤ 45, if female and age ≤ 55 If male and age > 45, if female and age > 55
BP BP < 90 90 ≤ BP ≤ 140 BP > 140
PulseRate PulseRate < 60 60 ≤ PulseRate ≤ 100 PulseRate > 100
TG TG < 200 TG ≥ 200
# discretize features
data$CR = as.numeric(cut(data$CR, breaks=c(-Inf, 0.7, 1.5, Inf), labels=c(0, 1, 2)))
data$FBS = as.numeric(cut(data$FBS, breaks=c(-Inf, 70, 105, Inf), labels=c(0, 1, 2)))
data$LDL = as.numeric(cut(data$LDL, breaks=c(-Inf, 130, Inf), labels=c(1, 2)))
data$HDL = as.numeric(cut(data$HDL, breaks=c(-Inf, 35, Inf), labels=c(0, 1)))
data$BUN = as.numeric(cut(data$BUN, breaks=c(-Inf, 7, 20, Inf), labels=c(0, 1, 2)))
data$K = as.numeric(cut(data$K, breaks=c(-Inf, 3.8, 5.6, Inf), labels=c(0, 1, 2)))
data$Na = as.numeric(cut(data$Na, breaks=c(-Inf, 136, 146, Inf), labels=c(0, 1, 2)))
data$WBC = as.numeric(cut(data$WBC, breaks=c(-Inf, 4000, 11000, Inf), labels=c(0, 1, 2)))
data$PLT = as.numeric(cut(data$PLT, breaks=c(-Inf, 150, 450, Inf), labels=c(0, 1, 2)))
data$EF.TTE = as.numeric(cut(data$EF.TTE, breaks=c(-Inf, 50, Inf), labels=c(0, 1)))
data$BP = as.numeric(cut(data$BP, breaks=c(-Inf, 90, 140, Inf), labels=c(0, 1, 2)))
data$PR = as.numeric(cut(data$PR, breaks=c(-Inf, 60, 100, Inf), labels=c(0, 1, 2)))
data$TG = as.numeric(cut(data$TG, breaks=c(-Inf, 200, Inf), labels=c(1, 2)))
data$Function.Class = as.numeric(cut(data$Function.Class, breaks=c(-Inf, 1.5, Inf), labels=c(1, 2)))
data$Region.RWMA = as.numeric(ifelse(data$Region.RWMA == 0, 1, 2))
data$ESR = as.numeric(with(data, ifelse((Sex == 1 & ESR <= Age/2) | (Sex == 0 & ESR <= Age/2 + 5), 1, 2)))
data$HB = as.numeric(with(data, ifelse((Sex == 1 & HB < 14) | (Sex == 0 & HB < 12.5), 0, ifelse((Sex == 1 & HB <= 17) | (Sex == 0 & HB <= 15), 1, 2))))
data$Age = as.numeric(with(data, ifelse((Sex == 1 & Age > 45) | (Sex == 0 & Age > 55), 2, 1)))
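One detail worth noting about the chunk above: `as.numeric()` applied to the factor returned by `cut()` yields the underlying integer codes 1..k, not the label values, so the `labels` arguments are effectively cosmetic; for a regression model this only recodes the predictors by a constant shift. A minimal standalone demonstration:

```r
# cut() returns a factor; as.numeric() on a factor gives the integer
# codes 1..k, not the label values
x <- c(0.5, 1.0, 2.0)
f <- cut(x, breaks = c(-Inf, 0.7, 1.5, Inf), labels = c(0, 1, 2))
as.numeric(f)                # integer codes: 1 2 3
as.numeric(as.character(f))  # label values:  0 1 2
```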

# update train and test datasets
train_data_discrete = data[train_indices, ]
test_data_discrete = data[-train_indices, ]

features = setdiff(colnames(data), exclude)

# pass parameters to format for JAGS
model_data = list(
  x = as.matrix(train_data_discrete[, features]),
  y = train_data_discrete$Cath,
  n_samples = nrow(train_data_discrete),
  n_features = length(features)
)

jags_model3 = jags(model.file=textConnection(jags_code),
                  data = model_data, 
                  inits = NULL,
                  n.chains = 5,
                  n.iter = 15000,
                  n.burnin = 5000,
                  parameters.to.save = c("beta0", "beta"))
Compiling model graph
   Resolving undeclared variables
   Allocating nodes
Graph information:
   Observed stochastic nodes: 227
   Unobserved stochastic nodes: 53
   Total graph size: 13680

Initializing model
jags_model3
Inference for Bugs model at "6", fit using jags,
 5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
 n.sims = 5000 iterations saved
         mu.vect sd.vect    2.5%     25%     50%     75%   97.5%  Rhat n.eff
beta[1]    2.252   0.726   0.881   1.760   2.247   2.733   3.726 1.005   660
beta[2]    0.010   0.668  -1.295  -0.444   0.010   0.458   1.310 1.001  5000
beta[3]    2.192   0.786   0.656   1.654   2.195   2.733   3.727 1.002  2000
beta[4]    1.949   0.650   0.682   1.517   1.934   2.379   3.242 1.002  2600
beta[5]   -0.034   0.769  -1.557  -0.555  -0.031   0.479   1.463 1.001  5000
beta[6]    1.010   1.241  -1.365   0.169   1.007   1.838   3.510 1.001  4200
beta[7]    1.493   0.832  -0.073   0.931   1.491   2.030   3.154 1.001  3600
beta[8]   -0.990   0.664  -2.310  -1.431  -0.985  -0.557   0.338 1.001  5000
beta[9]    0.104   1.519  -2.720  -0.925   0.059   1.128   3.042 1.001  3100
beta[10]  -0.195   1.268  -2.640  -1.036  -0.221   0.679   2.323 1.001  5000
beta[11]   1.172   1.186  -1.104   0.366   1.203   1.983   3.482 1.001  3300
beta[12]  -0.539   1.208  -2.864  -1.357  -0.552   0.263   1.884 1.001  3900
beta[13]   0.233   1.499  -2.677  -0.779   0.222   1.249   3.144 1.001  5000
beta[14]  -0.466   0.605  -1.632  -0.867  -0.465  -0.071   0.729 1.003  1100
beta[15]  -0.526   0.849  -2.220  -1.108  -0.508   0.071   1.071 1.004  1000
beta[16]   1.600   1.083  -0.376   0.855   1.543   2.271   4.008 1.004  1600
beta[17]   0.346   1.099  -1.748  -0.409   0.345   1.075   2.551 1.001  5000
beta[18]   0.569   1.419  -2.143  -0.398   0.538   1.489   3.422 1.003  1500
beta[19]   1.165   1.194  -1.129   0.346   1.172   1.986   3.476 1.001  5000
beta[20]   1.250   0.894  -0.474   0.656   1.255   1.826   3.042 1.001  5000
beta[21]  -0.611   1.183  -2.991  -1.407  -0.573   0.187   1.589 1.001  4500
beta[22]   3.495   0.833   1.921   2.923   3.473   4.041   5.169 1.003  1400
beta[23]  -1.562   0.659  -2.867  -2.010  -1.565  -1.108  -0.281 1.001  5000
beta[24]   0.818   0.714  -0.555   0.332   0.806   1.295   2.227 1.003  1200
beta[25]  -0.394   0.796  -1.962  -0.935  -0.397   0.155   1.157 1.001  3100
beta[26]  -1.246   0.998  -3.205  -1.917  -1.245  -0.593   0.728 1.002  2900
beta[27]   0.022   1.549  -3.015  -1.026   0.034   1.074   3.069 1.001  5000
beta[28]   1.263   1.301  -1.186   0.385   1.241   2.090   3.920 1.001  3600
beta[29]   1.123   1.289  -1.356   0.232   1.111   1.985   3.719 1.001  3200
beta[30]   1.289   0.843  -0.337   0.719   1.287   1.859   2.959 1.004   910
beta[31]   1.790   0.759   0.322   1.285   1.763   2.292   3.280 1.002  2900
beta[32]  -0.015   0.987  -1.836  -0.684  -0.035   0.625   1.957 1.001  3300
beta[33]   0.357   1.444  -2.416  -0.604   0.320   1.304   3.282 1.002  1800
beta[34]   0.274   0.664  -1.001  -0.178   0.270   0.716   1.610 1.007   490
beta[35]  -0.963   0.872  -2.686  -1.567  -0.966  -0.352   0.709 1.005   740
beta[36]   0.707   0.759  -0.746   0.189   0.713   1.212   2.183 1.004   790
beta[37]   0.697   0.729  -0.710   0.213   0.689   1.187   2.154 1.003  1400
beta[38]   0.103   0.652  -1.193  -0.328   0.102   0.548   1.364 1.003  1500
beta[39]   0.116   0.664  -1.169  -0.333   0.115   0.553   1.430 1.004  1100
beta[40]  -0.653   0.764  -2.141  -1.150  -0.659  -0.139   0.861 1.001  3300
beta[41]  -0.103   0.575  -1.239  -0.482  -0.110   0.281   1.023 1.001  5000
beta[42]  -0.176   0.720  -1.561  -0.679  -0.177   0.311   1.252 1.001  5000
beta[43]  -0.909   0.941  -2.778  -1.537  -0.913  -0.284   0.997 1.023   140
beta[44]   1.161   1.102  -0.856   0.404   1.132   1.884   3.372 1.019   170
beta[45]  -0.043   0.074  -0.190  -0.093  -0.041   0.004   0.108 1.028   110
beta[46]  -0.010   0.073  -0.153  -0.060  -0.007   0.035   0.140 1.031   100
beta[47]   0.130   1.069  -2.013  -0.579   0.108   0.861   2.204 1.006   580
beta[48]  -1.454   0.653  -2.749  -1.901  -1.441  -1.004  -0.189 1.002  1500
beta[49]   3.268   0.936   1.489   2.629   3.243   3.890   5.209 1.004   820
beta[50]  -0.927   0.531  -1.994  -1.283  -0.914  -0.556   0.074 1.002  2200
beta[51]  -0.151   1.109  -2.348  -0.906  -0.139   0.595   2.022 1.002  2200
beta[52]   0.071   1.168  -2.212  -0.692   0.080   0.855   2.372 1.001  5000
beta0     -7.353   6.605 -21.191 -11.398  -6.816  -2.952   4.771 1.040   100
deviance 100.803   8.419  85.863  94.886 100.395 106.189 118.479 1.001  5000

For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).

DIC info (using the rule, pD = var(deviance)/2)
pD = 35.5 and DIC = 136.3
DIC is an estimate of expected predictive error (lower deviance is better).
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model3)))

# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
  group_by(Parameter) %>%
  summarize(mean_value = mean(value)) %>%
  spread(key = Parameter, value = mean_value)

# prepare the test data with all predictors
model_test_data = as.matrix(test_data_discrete[, features])

# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 + 
  model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])

# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))

# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)

# calculate metrics
accuracy = sum(predictions == test_data_discrete$Cath) / nrow(test_data_discrete)
recall = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(test_data_discrete$Cath == 1)
precision = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)

# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy:  0.868421052631579"
print(paste("Recall: ", recall))
[1] "Recall:  0.962264150943396"
print(paste("Precision: ", precision))
[1] "Precision:  0.864406779661017"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score:  0.910714285714286"

4. Model with “extreme” feature selection and feature engineering

The final model we propose is much more extreme than the first three: we keep the discretization introduced above but perform a very aggressive feature selection, keeping only the features with a reasonably high Pearson correlation with the target.

# define threshold
threshold = 0.2

# get correlations
correlations = sapply(names(data)[names(data) != "Cath"], function(x){
  cor(data[[x]], data[["Cath"]])
})

correlations = data.frame(Feature = names(correlations), Correlation = correlations)

# select features with correlation above the threshold
features = (correlations %>% filter(abs(Correlation) > threshold))$Feature

features
 [1] "Age"                "DM"                 "HTN"               
 [4] "Typical.Chest.Pain" "Atypical"           "Nonanginal"        
 [7] "Tinversion"         "FBS"                "EF.TTE"            
[10] "Region.RWMA"       

Setting the threshold to 0.2 we get only 10 features. Compared to the previous models we can expect lower performance but much better convergence scores, given the lower complexity of the model.

# pass parameters to format for JAGS
model_data = list(
  x = as.matrix(train_data_discrete[, features]),
  y = train_data_discrete$Cath,
  n_samples = nrow(train_data_discrete),
  n_features = length(features)
)

jags_model4 = jags(model.file=textConnection(jags_code),
                  data = model_data, 
                  inits = NULL,
                  n.chains = 5,
                  n.iter = 15000,
                  n.burnin = 5000,
                  parameters.to.save = c("beta0", "beta"))
Compiling model graph
   Resolving undeclared variables
   Allocating nodes
Graph information:
   Observed stochastic nodes: 227
   Unobserved stochastic nodes: 11
   Total graph size: 3438

Initializing model
jags_model4
Inference for Bugs model at "7", fit using jags,
 5 chains, each with 15000 iterations (first 5000 discarded), n.thin = 10
 n.sims = 5000 iterations saved
         mu.vect sd.vect    2.5%     25%     50%     75%   97.5%  Rhat n.eff
beta[1]    1.405   0.510   0.424   1.058   1.397   1.737   2.436 1.005   620
beta[2]    1.365   0.601   0.188   0.950   1.368   1.771   2.558 1.004   940
beta[3]    1.275   0.454   0.422   0.962   1.268   1.568   2.180 1.001  5000
beta[4]    2.927   0.636   1.703   2.506   2.918   3.333   4.203 1.001  5000
beta[5]    0.026   0.604  -1.159  -0.387   0.023   0.432   1.216 1.001  4000
beta[6]   -1.116   0.814  -2.778  -1.646  -1.112  -0.560   0.440 1.001  4300
beta[7]    1.446   0.571   0.358   1.058   1.434   1.822   2.599 1.001  3400
beta[8]    0.475   0.512  -0.542   0.124   0.481   0.821   1.482 1.004   880
beta[9]   -0.897   0.447  -1.766  -1.196  -0.892  -0.596  -0.020 1.002  2400
beta[10]   3.103   0.775   1.650   2.563   3.082   3.606   4.683 1.002  1700
beta0     -7.265   1.933 -11.280  -8.477  -7.205  -5.956  -3.599 1.003  1300
deviance 120.630   4.721 113.145 117.277 119.947 123.426 131.114 1.002  1700

For each parameter, n.eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor (at convergence, Rhat=1).

DIC info (using the rule, pD = var(deviance)/2)
pD = 11.1 and DIC = 131.8
DIC is an estimate of expected predictive error (lower deviance is better).
# posterior samples of coefficients
beta_samples = as.data.frame(ggs(as.mcmc(jags_model4)))

# calculate the mean for each parameter (beta)
beta_means = beta_samples %>%
  group_by(Parameter) %>%
  summarize(mean_value = mean(value)) %>%
  spread(key = Parameter, value = mean_value)

# prepare the test data with all predictors
model_test_data = as.matrix(test_data_discrete[, features])

# calculate predicted probabilities using posterior means
logit_prediction = beta_means$beta0 + 
  model_test_data %*% as.numeric(beta_means[grep("beta(?!0)", names(beta_means), perl = TRUE)])

# convert log-odds to probabilities
pred_probs = 1 / (1 + exp(- logit_prediction))

# convert probabilities to class predictions
predictions = ifelse(pred_probs >= 0.5, 1, 0)

# calculate metrics
accuracy = sum(predictions == test_data_discrete$Cath) / nrow(test_data_discrete)
recall = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(test_data_discrete$Cath == 1)
precision = sum(predictions == 1 & test_data_discrete$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)

# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy:  0.868421052631579"
print(paste("Recall: ", recall))
[1] "Recall:  0.943396226415094"
print(paste("Precision: ", precision))
[1] "Precision:  0.87719298245614"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score:  0.909090909090909"

Models comparison

One of the main criteria used for model selection among Bayesian models fitted with Markov chain Monte Carlo is the Deviance Information Criterion (DIC). Under the assumption of an approximately multivariate normal posterior for the parameters, this criterion produces a score (the lower the better) that rewards goodness of fit to the data and penalizes the complexity of the model.
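To make the rule R2jags prints concrete ("pD = var(deviance)/2"), here is a toy hand computation on a made-up vector of posterior deviance draws (not output from any of our fitted models):

```r
# DIC rule used by R2jags: pD = var(deviance) / 2, DIC = mean(deviance) + pD
set.seed(1)
deviance_samples <- rnorm(5000, mean = 120, sd = 4.7)  # hypothetical draws
pD  <- var(deviance_samples) / 2     # effective number of parameters
DIC <- mean(deviance_samples) + pD   # lower is better
c(pD = pD, DIC = DIC)
```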

Comparing the four models, we see the first two achieving the same performance in terms of metrics, with the second having a lower DIC. Both models, however, show some convergence problems, as can be observed from some Rhat values being \(>1.1\) despite the high number of iterations.

The model that uses feature discretization performs better on all metrics and reaches a lower DIC than the first model, while having the same number of features.

The last model, despite the drastic reduction in the number of features, outperforms the first two on all metrics and has about the same F1-score as the third, with a slightly lower recall and slightly higher precision, while achieving the lowest DIC of all models.

In light of these observations, the best models among the four proposed are the third and the fourth; in the next part of the report we focus on these two to study convergence.

MCMC convergence diagnostics

What we are doing with JAGS is estimating the posterior distribution of the parameters from their priors and the data using Markov chain Monte Carlo (MCMC). If the MCMC has not converged, the sampling will be biased, leading to inaccurate predictions. Therefore, when using JAGS it is critically important to check the convergence of the chains.

There are many different diagnostics for the convergence of MCMC. To get an overview of the available options we explored the R library ggmcmc, which contains various tools for assessing and diagnosing convergence, and decided to use the Gelman-Rubin statistic (R-hat) and the effective sample size (ESS) (both also included in the summary printed at the end of the JAGS execution), Geweke's diagnostic, autocorrelation plots, trace plots and density plots.
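As a minimal, self-contained sketch of how ggmcmc produces this kind of diagnostic (run here on two toy chains for a single made-up parameter, not on our models):

```r
library(coda)
library(ggmcmc)

# two short toy chains for one parameter, in coda's mcmc.list format
set.seed(1)
toy <- mcmc.list(mcmc(cbind(beta0 = rnorm(500))),
                 mcmc(cbind(beta0 = rnorm(500))))
S <- ggs(toy)  # convert to the tidy format ggmcmc works with

ggs_geweke(S)           # Geweke Z-scores
ggs_Rhat(S)             # Gelman-Rubin R-hat
ggs_autocorrelation(S)  # lag-k autocorrelation
ggs_traceplot(S)        # trace plots
ggs_density(S)          # posterior densities per chain
```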

Geweke’s diagnostic

A diagnostic for the convergence of MCMC proposed by Geweke in 1992. It is a hypothesis test whose null hypothesis is that the Markov chain is in its stationary distribution, based on a test for equality of the means of the first and last parts of the chain. The reported value is the Z-score: the difference between the two sample means divided by its estimated standard error.
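The same statistic can be computed numerically with the coda package; a sketch on a toy chain (already stationary by construction):

```r
library(coda)

set.seed(42)
chain <- mcmc(rnorm(5000))  # toy chain drawn from its stationary distribution
# by default compares the mean of the first 10% of the chain
# against the mean of the last 50%
geweke.diag(chain)          # |Z| well above 2 would indicate non-convergence
```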

Plot for model 3:

Plot for model 4:

Gelman-Rubin statistic (R-hat)

The Gelman-Rubin statistic, also known as the Potential Scale Reduction Factor or simply R-hat, is a statistic to assess convergence of Markov chain Monte Carlo proposed by Gelman and Rubin in 1992. Its value is:

\[ \hat{R} = \frac{\frac{n-1}{n}W + \frac{1}{n}B}{W} \]

where \(n\) is the length of each chain, \(B\) is the between-chain variance (computed from the chain means) and \(W\) is the average within-chain variance. The usual threshold to assess convergence is \(\hat{R} < 1.1\).
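To make the formula concrete, a hand computation on two well-mixed toy chains, taking \(B\) as \(n\) times the variance of the chain means so that \(B/n\) matches the formula above:

```r
set.seed(7)
chains <- list(rnorm(1000), rnorm(1000))  # two toy chains of length n
n <- 1000
W <- mean(sapply(chains, var))       # mean within-chain variance
B <- n * var(sapply(chains, mean))   # between-chain variance from the means
rhat <- ((n - 1) / n * W + B / n) / W
rhat  # close to 1 for well-mixed chains
```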

Plot for model 3:

Plot for model 4:

As noted before, the first two models have some Rhat values above the threshold, leading us to conclude that they had convergence issues. Both the third and the fourth model have values below the threshold, and the fourth in particular has all values well below it.

Effective sample size (ESS) and autocorrelation

Within the same chain, samples tend to be autocorrelated. The effective sample size is an estimate of the size of a simple random sample that would achieve the same level of precision.

The plots show the lag-k autocorrelation: the correlation between a sample and the sample k steps before. This value should shrink toward zero as k increases, indicating that sufficiently spaced samples can be treated as approximately independent.
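A self-contained illustration of how autocorrelation deflates the ESS, using an autocorrelated AR(1) toy chain (not one of our model chains) and coda's `effectiveSize`:

```r
library(coda)

set.seed(3)
# AR(1) toy chain: positive autocorrelation shrinks the effective sample size
chain <- as.numeric(arima.sim(model = list(ar = 0.8), n = 5000))
effectiveSize(mcmc(chain))  # far below the nominal 5000 draws
acf(chain, lag.max = 40)    # lag-k autocorrelation decaying toward zero
```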

The ESS is already reported in the JAGS output, while below we show the autocorrelation plots:

Plot for model 3:

Plot for model 4:

We can see from both the autocorrelation plots and the low effective sample sizes that, for the third model, the parameters beta[45] and beta[46] and the intercept beta0 present some issues: their curves remain significantly different from zero even at large k, even though they appear to keep decreasing.

Meanwhile, the fourth model seems to present no issues at all for this diagnostic.

Trace plots

Trace plots show the behavior of each chain for each parameter over the iterations. In a trace plot we want to avoid flat stretches, where the chain stays in the same state for too long, and long runs of consecutive steps in the same direction.

Plot for model 3:

Plot for model 4:

This diagnostic confirms what we have already seen: for the third model some parameters seem not to have converged yet; critical behaviors are present in the intercept beta0 and also in some coefficients, more than we had diagnosed with the previous tools.

Density plots

These are density plots of the posterior distributions of the parameters. Since we have multiple chains, the plots for each parameter are superimposed. Similar distributions across chains for the same parameter are a good sign of convergence.

Plot for model 3:

Plot for model 4:

This diagnostic also confirms what was said above: for the third model the intercept and some coefficients have distributions that differ greatly between chains, signaling convergence issues, while in the fourth model we observe much more uniform behavior.

Comparative analysis with frequentist inference

As suggested by the Final Project Guidelines, we compare our models with what can be obtained with a frequentist model. This is of interest as empirical evidence of the effectiveness of the Bayesian approach.

features = colnames(data)[colnames(data) != "Cath"]

model_data = train_data[, features]

# fit the regression model
# (note: family = binomial is not specified, so glm() defaults to gaussian
# and fits a linear probability model rather than a logistic regression,
# as confirmed by the "gaussian family" line in the summary below)
logistic_model = glm(Cath ~ Age + Weight + Length + Sex + BMI + DM + HTN +
                            Current.Smoker + EX.Smoker + FH + Obesity + CRF +
                            CVA + Airway.disease + Thyroid.Disease + CHF + DLP +
                            BP + PR + Edema + Weak.Peripheral.Pulse + Lung.rales +
                            Systolic.Murmur + Diastolic.Murmur + Typical.Chest.Pain +
                            Dyspnea + Function.Class + Atypical + Nonanginal +
                            LowTH.Ang + Q.Wave + St.Elevation + St.Depression +
                            Tinversion + LVH + Poor.R.Progression + FBS + CR +
                            TG + LDL + HDL + BUN + ESR + HB + K + Na + WBC +
                            Lymph + Neut + PLT + EF.TTE + Region.RWMA + VHD +
                            LBBB + RBBB,
                      data = train_data)

summary(logistic_model)

Call:
glm(formula = Cath ~ Age + Weight + Length + Sex + BMI + DM + 
    HTN + Current.Smoker + EX.Smoker + FH + Obesity + CRF + CVA + 
    Airway.disease + Thyroid.Disease + CHF + DLP + BP + PR + 
    Edema + Weak.Peripheral.Pulse + Lung.rales + Systolic.Murmur + 
    Diastolic.Murmur + Typical.Chest.Pain + Dyspnea + Function.Class + 
    Atypical + Nonanginal + LowTH.Ang + Q.Wave + St.Elevation + 
    St.Depression + Tinversion + LVH + Poor.R.Progression + FBS + 
    CR + TG + LDL + HDL + BUN + ESR + HB + K + Na + WBC + Lymph + 
    Neut + PLT + EF.TTE + Region.RWMA + VHD + LBBB + RBBB, data = train_data)

Coefficients:
                        Estimate Std. Error t value Pr(>|t|)    
(Intercept)            3.746e+00  3.109e+00   1.205  0.23004    
Age                    1.101e-02  2.727e-03   4.038 8.12e-05 ***
Weight                 2.405e-02  1.909e-02   1.260  0.20926    
Length                -2.477e-02  1.772e-02  -1.398  0.16387    
Sex                    1.048e-01  8.037e-02   1.304  0.19404    
BMI                   -7.557e-02  5.038e-02  -1.500  0.13550    
DM                     2.073e-01  7.236e-02   2.865  0.00469 ** 
HTN                    9.518e-02  6.475e-02   1.470  0.14344    
Current.Smoker         9.230e-02  6.804e-02   1.357  0.17670    
EX.Smoker              5.140e-02  1.555e-01   0.331  0.74135    
FH                     1.576e-01  7.188e-02   2.193  0.02969 *  
Obesity                3.448e-02  7.600e-02   0.454  0.65062    
CRF                   -9.709e-02  1.743e-01  -0.557  0.57829    
CVA                    2.075e-02  1.677e-01   0.124  0.90164    
Airway.disease         1.343e-01  1.203e-01   1.116  0.26594    
Thyroid.Disease        1.038e-03  1.569e-01   0.007  0.99473    
CHF                    2.907e-01  4.098e-01   0.709  0.47909    
DLP                   -2.869e-02  4.985e-02  -0.575  0.56572    
BP                     8.179e-04  1.576e-03   0.519  0.60441    
PR                     6.578e-03  3.182e-03   2.067  0.04023 *  
Edema                 -7.337e-02  1.218e-01  -0.602  0.54766    
Weak.Peripheral.Pulse -9.145e-03  1.991e-01  -0.046  0.96341    
Lung.rales             3.139e-01  1.519e-01   2.066  0.04030 *  
Systolic.Murmur        2.358e-02  8.729e-02   0.270  0.78736    
Diastolic.Murmur      -1.349e-01  1.521e-01  -0.887  0.37648    
Typical.Chest.Pain     2.595e-01  8.304e-02   3.125  0.00209 ** 
Dyspnea               -6.490e-02  5.758e-02  -1.127  0.26130    
Function.Class         4.563e-02  2.649e-02   1.723  0.08678 .  
Atypical              -9.878e-02  8.901e-02  -1.110  0.26863    
Nonanginal            -1.839e-01  1.249e-01  -1.473  0.14254    
LowTH.Ang             -1.958e-01  3.737e-01  -0.524  0.60100    
Q.Wave                -9.562e-02  1.448e-01  -0.660  0.50994    
St.Elevation           1.833e-01  1.478e-01   1.240  0.21652    
St.Depression         -3.279e-03  6.757e-02  -0.049  0.96135    
Tinversion             1.028e-01  6.233e-02   1.649  0.10102    
LVH                    6.016e-02  1.164e-01   0.517  0.60610    
Poor.R.Progression     6.398e-02  1.593e-01   0.402  0.68838    
FBS                    2.994e-05  6.210e-04   0.048  0.96160    
CR                     2.895e-02  1.131e-01   0.256  0.79830    
TG                     5.121e-04  2.892e-04   1.771  0.07840 .  
LDL                   -4.419e-04  7.523e-04  -0.587  0.55774    
HDL                   -9.285e-04  2.271e-03  -0.409  0.68317    
BUN                   -7.088e-03  4.086e-03  -1.735  0.08458 .  
ESR                    1.342e-03  1.845e-03   0.728  0.46781    
HB                    -2.543e-03  1.850e-02  -0.137  0.89080    
K                      1.047e-01  5.466e-02   1.915  0.05716 .  
Na                    -5.648e-03  6.825e-03  -0.828  0.40905    
WBC                   -8.267e-06  1.321e-05  -0.626  0.53227    
Lymph                  8.307e-04  6.293e-03   0.132  0.89513    
Neut                   2.149e-03  6.157e-03   0.349  0.72748    
PLT                    1.966e-05  4.221e-04   0.047  0.96290    
EF.TTE                 1.799e-03  3.594e-03   0.501  0.61732    
Region.RWMA            4.716e-02  2.526e-02   1.867  0.06361 .  
VHD                   -7.498e-02  4.167e-02  -1.799  0.07372 .  
LBBB                  -1.318e-01  1.249e-01  -1.055  0.29307    
RBBB                   2.637e-02  1.256e-01   0.210  0.83390    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for gaussian family taken to be 0.1009441)

    Null deviance: 45.956  on 226  degrees of freedom
Residual deviance: 17.261  on 171  degrees of freedom
AIC: 173.34

Number of Fisher Scoring iterations: 2
# make inference on test data

pred_values = predict(logistic_model, test_data, type = "response")
predictions = ifelse(pred_values > 0.5, 1, 0)

# calculate metrics
accuracy = mean(predictions == test_data$Cath)
recall = sum(predictions == 1 & test_data$Cath == 1) / sum(test_data$Cath == 1)
precision = sum(predictions == 1 & test_data$Cath == 1) / sum(predictions == 1)
f1_score = 2 * (precision * recall) / (precision + recall)

# print metrics
print(paste("Accuracy: ", accuracy))
[1] "Accuracy:  0.828947368421053"
print(paste("Recall: ", recall))
[1] "Recall:  0.830188679245283"
print(paste("Precision: ", precision))
[1] "Precision:  0.916666666666667"
print(paste("F1 Score: ", f1_score))
[1] "F1 Score:  0.871287128712871"

We can observe that the frequentist approach performs worse overall than the Bayesian one. It does achieve better precision but, as stated in the "Goals" section of the report, we prefer models with high recall.

Conclusions

We successfully used JAGS to estimate the model parameters and built two models with satisfactory performance.

The third model is the most performant but suffers from convergence issues that make it less reliable. It is possible that these problems are linked to the high number of features and could be solved by running more iterations, at the cost of increased computational time.

Meanwhile, the fourth model is much lighter and reaches convergence quickly, at the cost of being a simpler model with slightly lower performance.

References

Alizadehsani, R., Roshanzamir, M., & Sani, Z. (2013). Z-Alizadeh Sani [Dataset]. UCI Machine Learning Repository. https://doi.org/10.24432/C5Q31T.

Alizadehsani, R. et al. A data mining approach for diagnosis of coronary artery disease. Comput. Methods Programs Biomed. 111, 52–61 (2013).

National Heart, Lung and Blood Institute. What is Coronary Heart Disease? https://www.nhlbi.nih.gov/health/coronary-heart-disease

Department of Health, New York State. Heart Disease and Stroke Prevention. https://www.health.ny.gov/diseases/cardiovascular/heart_disease/